import copy
import sys
from collections import OrderedDict

import torch
import torch.nn.utils.prune as prune
from torch.cuda.amp import autocast

from optimizers.operators import ProximalOperator
from torchmetrics import MeanMetric
from utilities.utilities import Utilities as Utils
from math import sqrt

#### Dense Base Class
class Dense:
    """Dense base class for defining callbacks; does nothing itself but define the hook structure.

    Concrete strategies inherit from this class and override the hooks they need. The pruning
    helpers (``pruning_step``, ``enforce_prunedness``, ``make_pruning_permanent``, ...) operate
    on ``self.parameters_to_prune``, which is populated by ``after_initialization``.
    """
    required_params = []

    def __init__(self, **kwargs):
        # (module, param_type) -> mask tensor; only filled when pruning with only_save_mask=True
        self.masks = dict()
        self.lr_dict = OrderedDict()  # it:lr
        self.is_in_finetuning_phase = False

        self.model = kwargs['model']
        self.optimizer = kwargs['optimizer']
        self.run_config = kwargs['config']
        self.callbacks = kwargs['callbacks']
        self.goal_sparsity = self.run_config['goal_sparsity']
        self.n_total_iterations = kwargs['n_total_iterations']

        self.metrics = {}

    def after_initialization(self):
        """Called after initialization of the strategy.

        Every module that owns a non-None ``weight`` (except ``BatchNorm2d``) becomes a pruning
        target ``(module, 'weight')``; ``n_prunable_parameters`` is the total number of those weights.
        """
        self.parameters_to_prune = [
            (module, 'weight') for module in self.model.modules()
            if getattr(module, 'weight', None) is not None
            and not isinstance(module, torch.nn.BatchNorm2d)
        ]
        self.n_prunable_parameters = sum(
            getattr(module, param_type).numel() for module, param_type in self.parameters_to_prune)

    @torch.no_grad()
    def start_forward_mode(self, **kwargs):
        """Function to be called before Forward step."""
        pass

    @torch.no_grad()
    def end_forward_mode(self, **kwargs):
        """Function to be called after Forward step."""
        pass

    @torch.no_grad()
    def adjust_train_sampler(self, **kwargs):
        """Called before each training epoch, should return None if not modified."""
        return None

    @torch.no_grad()
    def before_backward(self, **kwargs):
        """Called after the forward step and before loss.backward(); must return the (possibly modified) loss."""
        return kwargs['loss']

    @torch.no_grad()
    def during_training(self, **kwargs):
        """Function to be called after loss.backward() and before optimizer.step, e.g. to mask gradients."""
        pass

    @torch.no_grad()
    def after_training_iteration(self, **kwargs):
        """Called after each training iteration; records the learning rate outside of finetuning."""
        if not self.is_in_finetuning_phase:
            self.lr_dict[kwargs['it']] = kwargs['lr']

    def at_train_begin(self):
        """Called before training begins"""
        pass

    def adjust_train_target(self, **kwargs):
        """Modifies the train targets"""
        return kwargs['y_target']

    def adjust_loss_fn(self, **kwargs):
        """Modifies the default loss criterion FUNCTION"""
        return kwargs['loss_criterion']

    def modify_loss(self, **kwargs):
        """Makes changes to the loss, e.g. for knowledge distillation."""
        return kwargs['loss']

    def at_epoch_start(self, **kwargs):
        """Called before the epoch starts"""
        pass

    def at_epoch_end(self, **kwargs):
        """Called at epoch end"""
        pass

    def at_train_end(self, **kwargs):
        """Called at the end of training"""
        pass

    def final(self):
        """Called once at the very end; makes the pruning permanent."""
        self.make_pruning_permanent()

    @torch.no_grad()
    def pruning_step(self, pruning_sparsity, only_save_mask=False, compute_from_scratch=False):
        """Prune ``self.parameters_to_prune`` by ``pruning_sparsity``.

        Args:
            pruning_sparsity: ``amount`` passed on to the torch pruning routines.
            only_save_mask: if True, the resulting masks are stored in ``self.masks`` and the
                pruning reparametrization is removed again without zeroing any weights.
            compute_from_scratch: if True, modules that are currently pruned are first reverted
                to their unmasked ``*_orig`` values, so the new mask ignores the previous one.

        If ``run_config['pruning_selector'] == 'uniform'`` every layer is pruned individually by
        L1 magnitude; otherwise pruning is global using ``self.get_pruning_method()``.
        """
        if compute_from_scratch:
            # We have to revert to weight_orig and then compute the mask
            for module, param_type in self.parameters_to_prune:
                if prune.is_pruned(module):
                    # prune.remove bakes the mask into the weight, so restore the unmasked values afterwards
                    orig = getattr(module, param_type + "_orig").detach().clone()
                    prune.remove(module, param_type)
                    getattr(module, param_type).copy_(orig)
                    del orig
        elif only_save_mask and self.masks:
            # Re-attach the stored masks so that the new pruning builds on top of them
            for module, param_type in self.parameters_to_prune:
                if (module, param_type) in self.masks:
                    prune.custom_from_mask(module, param_type, self.masks[(module, param_type)])

        if self.run_config['pruning_selector'] == 'uniform':
            # We prune each layer individually
            for module, param_type in self.parameters_to_prune:
                prune.l1_unstructured(module, name=param_type, amount=pruning_sparsity)
        else:
            # Default: prune globally
            prune.global_unstructured(
                self.parameters_to_prune,
                pruning_method=self.get_pruning_method(),
                amount=pruning_sparsity,
            )

        self.masks = dict()  # Stays empty if we use regular pruning
        if only_save_mask:
            for module, param_type in self.parameters_to_prune:
                if prune.is_pruned(module):
                    # Save the mask
                    mask = getattr(module, param_type + '_mask')
                    self.masks[(module, param_type)] = mask.detach().clone()
                    # Replace the mask by ones so that removing the reparametrization leaves weight == weight_orig
                    setattr(module, param_type + '_mask', torch.ones_like(mask))
                    prune.remove(module=module, name=param_type)
                    del mask

    def enforce_prunedness(self):
        """
        Makes the pruning permanent, i.e. sets the pruned weights to zero, then reinitializes from the same mask.
        This ensures that we can actually work (i.e. LMO, rescale computation) with the parameters.
        Important: For this to work we require that pruned weights stay zero in weight_orig over training,
        hence training, projecting etc. should not modify (pruned) 0 weights in weight_orig.
        """
        for module, param_type in self.parameters_to_prune:
            if prune.is_pruned(module):
                # Keep a reference to the mask; prune.remove deletes the buffer from the module
                mask = getattr(module, param_type + '_mask')
                # Remove (i.e. make permanent) the reparameterization
                prune.remove(module=module, name=param_type)
                # Reinitialize the pruning with the same mask
                prune.custom_from_mask(module=module, name=param_type, mask=mask)
                del mask

    def prune_momentum(self):
        """Multiply the optimizer's momentum buffers of pruned parameters by their masks."""
        opt_state = self.optimizer.state
        for module, param_type in self.parameters_to_prune:
            if prune.is_pruned(module):
                # The optimizer state is keyed by the trainable tensor, i.e. weight_orig
                param_state = opt_state[getattr(module, param_type + "_orig")]
                if 'momentum_buffer' in param_state:
                    mask = getattr(module, param_type + "_mask")
                    param_state['momentum_buffer'] *= mask.to(dtype=param_state['momentum_buffer'].dtype)

    def get_pruning_method(self):
        """Return the torch pruning method class used for global pruning; must be overridden."""
        raise NotImplementedError("Dense has no pruning method, this must be implemented in each child class.")

    @torch.no_grad()
    def make_pruning_permanent(self):
        """Makes the pruning permanent and removes the pruning hooks.

        Without saved masks, the reparametrization of every pruned module is removed (pruned
        entries become zero). With saved masks (only_save_mask mode) the masks are multiplied
        into the weights and then discarded.
        """
        if not self.masks:
            for module, param_type in self.parameters_to_prune:
                if prune.is_pruned(module):
                    prune.remove(module, param_type)
        else:
            for (module, param_type), mask in self.masks.items():
                weight = getattr(module, param_type)
                weight *= mask
            self.masks = dict()

    def set_to_finetuning_phase(self):
        """Switch to finetuning; learning rates are no longer recorded afterwards."""
        self.is_in_finetuning_phase = True

    def get_strategy_metrics(self):
        """Return ``{metric_name: computed value}`` for all registered strategy metrics."""
        return {metric_name: metric.compute() for metric_name, metric in self.metrics.items()}

    def reset_strategy_metrics(self):
        """Reset all registered strategy metrics."""
        for metric in self.metrics.values():
            metric.reset()


#### Pruning stable (UNSTRUCTURED) strategies
class GSM(Dense):
    """Global Sparse Momentum as by Ding et al. 2019.

    After ``delay_it`` iterations, only the ``n_remaining_params`` gradient entries with the
    largest saliency ``|p * p.grad|`` (ranked globally over every optimizer parameter that has a
    gradient) are kept; all other gradient entries are zeroed before the optimizer step.
    """
    required_params = ['goal_sparsity', 'gsm_delay']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.n_remaining_params = None  # Number of parameters to remain after pruning
        # Fraction of the total iterations during which gradients are not masked
        self.gsm_delay = self.run_config['gsm_delay'] or 0
        self.delay_it = int(self.gsm_delay * self.n_total_iterations)

    def after_initialization(self):
        super().after_initialization()
        # Compute n_remaining_params from the prunable weights
        self.n_remaining_params = int((1 - self.goal_sparsity) * self.n_prunable_parameters)

    @torch.no_grad()
    def during_training(self, **kwargs) -> None:
        """Apply the global top-k saliency mask to the gradients (in place)."""
        if kwargs['trainIteration'] < self.delay_it:
            return

        # NOTE(review): n_remaining_params counts only the prunable weights, while the top-k below
        # ranks every parameter with a gradient (e.g. biases, BatchNorm) - confirm this is intended.
        param_list = [p for group in self.optimizer.param_groups
                      for p in group['params'] if p.grad is not None]
        if not param_list:
            # No gradients to mask (torch.cat on an empty list would raise)
            return

        saliency = torch.cat([self.saliency_criterion(p=p).view(-1) for p in param_list])
        n_entries = saliency.numel()
        device = saliency.device
        top_indices = torch.topk(saliency, k=min(self.n_remaining_params, n_entries)).indices
        del saliency  # free the saliency vector before allocating the mask
        mask_vector = torch.zeros(n_entries, device=device)
        mask_vector[top_indices] = 1

        # Split the flat mask back into one (view) chunk per parameter, in the same order
        sizes = [q.numel() for q in param_list]
        for p, partial_mask in zip(param_list, torch.split(mask_vector, sizes)):
            p.grad.mul_(partial_mask.view(p.shape))  # Mask gradient

    def saliency_criterion(self, p):
        """Return the saliency of parameter ``p``, i.e. ``torch.abs(p * p.grad)``."""
        return torch.abs(p * p.grad)




#### Dense Base Class for Structured approaches
class StructDense(Dense):
    """Base class for filter pruning oriented dense training.

    Only ``torch.nn.Conv2d`` weights are pruning targets, and ``pruning_step`` removes whole
    filters (dim 0) layer by layer via ``prune.ln_structured``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def after_initialization(self):
        """Collect the weight of every Conv2d module as a ``(module, 'weight')`` pruning target."""
        targets = []
        for _, module in self.model.named_modules():
            if not isinstance(module, torch.nn.Conv2d):
                continue
            if getattr(module, 'weight', None) is None:
                continue
            targets.append((module, 'weight'))
        self.parameters_to_prune = targets
        self.n_prunable_parameters = sum(getattr(mod, attr).numel() for mod, attr in targets)

    @torch.no_grad()
    def pruning_step(self, pruning_sparsity, only_save_mask=False, compute_from_scratch=False, pruning_norm=1):
        """Structured (filter-wise) pruning of every target layer by ``pruning_sparsity``.

        Args:
            pruning_sparsity: ``amount`` forwarded to ``prune.ln_structured`` for each layer.
            only_save_mask: store the masks in ``self.masks`` and remove the reparametrization
                without zeroing any weights.
            compute_from_scratch: revert currently pruned modules to their unmasked ``*_orig``
                values before computing the new masks.
            pruning_norm: the ``n`` of the Ln norm used to rank filters.
        """
        if compute_from_scratch:
            for mod, attr in self.parameters_to_prune:
                if not prune.is_pruned(mod):
                    continue
                # prune.remove bakes the mask in; restore the unmasked values right after
                unmasked = getattr(mod, attr + "_orig").detach().clone()
                prune.remove(mod, attr)
                getattr(mod, attr).copy_(unmasked)
                del unmasked
        elif only_save_mask and self.masks:
            # Re-attach previously saved masks so the new pruning builds on them
            for key in self.parameters_to_prune:
                saved_mask = self.masks.get(key)
                if saved_mask is not None:
                    prune.custom_from_mask(key[0], key[1], saved_mask)

        # We prune filters locally
        sys.stdout.write(f"\nPruning by l{pruning_norm} norm.")
        for mod, attr in self.parameters_to_prune:
            prune.ln_structured(mod, attr, pruning_sparsity, n=pruning_norm, dim=0)

        self.masks = {}  # Stays empty if we use regular pruning
        if not only_save_mask:
            return
        for mod, attr in self.parameters_to_prune:
            if not prune.is_pruned(mod):
                continue
            current_mask = getattr(mod, attr + '_mask')
            self.masks[(mod, attr)] = current_mask.detach().clone()
            # A ones-mask makes prune.remove leave weight == weight_orig
            setattr(mod, attr + '_mask', torch.ones_like(current_mask))
            prune.remove(module=mod, name=attr)
            del current_mask

#### Pruning stable (STRUCTURED) strategies
class SparseFW(Dense):
    """Strategy that only declares its ``lmo*`` configuration keys; every hook is inherited from Dense."""

    required_params = [
        'lmo', 'lmo_mode', 'lmo_ord', 'lmo_value', 'lmo_k',
        'lmo_rescale', 'lmo_global', 'lmo_delay', 'lmo_nuc_method',
    ]

## Filter based
class SSL(StructDense):
    """Structured Sparsity Learning (SSL) as in Wen et al. (2016). Essentially a group penalty on the filters.

    The loss is augmented by ``group_penalty * sum_f ||W_f||_2``, summed over the filters
    (rows of the weight flattened from dim 1) of every prunable layer.
    """
    required_params = ['group_penalty']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        assert self.group_penalty > 0

    # Important: no torch.no_grad, the penalty must be differentiable
    def before_backward(self, **kwargs):
        """Return the loss plus the filter-wise group-lasso penalty."""
        loss = kwargs['loss']
        if not self.group_penalty > 0:
            return loss
        for module, param_type in self.parameters_to_prune:
            # Penalize the trainable tensor: weight_orig when the module carries a pruning reparametrization
            weight = getattr(module, param_type + "_orig", None)
            if weight is None:
                weight = getattr(module, param_type)
            filter_norms = torch.linalg.norm(weight.flatten(start_dim=1), ord=2, dim=1)
            loss = loss + self.group_penalty * torch.sum(filter_norms)
        return loss

class GLT(StructDense):
    """Group Lasso Thresholding (GLT) approach by Alvarez et al. (2016). Implements a proximal step at the end of each epoch.

    The step first soft-thresholds every weight elementwise by ``lr * lasso_tradeoff * group_penalty``
    and then shrinks each filter by the group factor
    ``relu(1 - lr * (1 - lasso_tradeoff) * group_penalty * sqrt(filter_size) / ||S_f||_2)``.
    """
    required_params = ['group_penalty', 'lasso_tradeoff']

    # Lower bound for the filter norms in the group-shrinkage denominator. Without it an all-zero
    # filter combined with a zero numerator (lasso_tradeoff == 1 or lr == 0) gives 0/0 = NaN.
    _norm_eps = 1e-12

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        self.lasso_tradeoff = self.run_config['lasso_tradeoff']
        assert self.group_penalty > 0
        assert 0 <= self.lasso_tradeoff <= 1

    def at_epoch_end(self, **kwargs):
        """Apply the proximal step to every prunable weight (skipped at epoch 0 and at the last epoch)."""
        super().at_epoch_end(**kwargs)
        epoch = kwargs['epoch']
        if epoch == 0 or epoch == self.run_config['n_epochs']:
            return
        lr = float(self.optimizer.param_groups[0]['lr'])
        elementwise_threshold = lr * self.lasso_tradeoff * self.group_penalty
        group_threshold = lr * (1 - self.lasso_tradeoff) * self.group_penalty
        with torch.no_grad():
            # NOTE(review): if a module carries a pruning reparametrization, getattr returns the derived
            # 'weight' tensor rather than 'weight_orig' - confirm modules are unpruned during training.
            for module, param_type in self.parameters_to_prune:
                p = getattr(module, param_type)
                # S can be defined elementwise
                S = torch.sign(p) * torch.nn.functional.relu(torch.abs(p) - elementwise_threshold)
                S_flat = S.flatten(start_dim=1)
                # S_norm must be computed per filter; the clamp only matters for (near-)zero filters,
                # whose result is zero anyway, and prevents NaN factors
                S_norm = torch.linalg.norm(S_flat, ord=2, dim=1).clamp_min(self._norm_eps)
                P_l = sqrt(p[0].numel())  # sqrt of the number of weights per filter

                factorPerFilter = torch.nn.functional.relu(1 - group_threshold * P_l / S_norm)
                proxGD_result = S_flat * factorPerFilter[:, None]
                p.copy_(proxGD_result.view(p.shape))


class ABFP(StructDense):
    """Auto-Balanced Filter Pruning as by Ding et al. (2018).

    Per layer, the ``abfp_k`` fraction of filters with the largest l1 norm get a negative penalty
    coefficient (they are allowed to grow); all others get a positive coefficient. The coefficient
    magnitudes grow logarithmically with the distance of a filter's norm from the threshold.
    """
    required_params = ['group_penalty', 'abfp_k']   # penalty strength and fraction of filters per layer to allow growing

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        self.abfp_k = self.run_config['abfp_k']
        assert self.group_penalty > 0
        assert 0 <= self.abfp_k <= 1

        self.penaltyVecDict = {}  # (module, param_type) -> per-filter penalty coefficients

        self.eps = 1e-12 # To avoid numerical issues

    @torch.no_grad()
    def after_initialization(self):
        """Define the per-filter penalties from the initial filter l1 norms."""
        super().after_initialization()

        for target in self.parameters_to_prune:
            p = getattr(*target).flatten(start_dim=1)   # Shape: (n_filters, *)
            M = torch.linalg.norm(p, ord=1, dim=1)   # l1 norm per filter
            n_filters = M.numel()
            # Number of growing filters; at least one, which keeps the balancing term Sr in
            # before_backward from being zero
            k = max(int(self.abfp_k * n_filters), 1)
            # theta is the k-th largest norm, i.e. the (n - k + 1)-th smallest (kthvalue is 1-based),
            # so that (ties aside) exactly k filters satisfy M >= theta
            theta = torch.kthvalue(input=M, k=n_filters - k + 1).values
            # The increasing group Ri is all filters with l_1 norm >= theta, the rest the smaller ones
            penaltyVec = torch.zeros_like(M)
            penaltyVec[M < theta] = 1. + torch.log(theta / (M[M < theta] + self.eps))
            penaltyVec[M >= theta] = -1. - torch.log(M[M >= theta] / (theta + self.eps))

            self.penaltyVecDict[target] = penaltyVec

    # Important: no torch.no_grad
    def before_backward(self, **kwargs):
        """Add the balanced group penalty. Needs to return the modified loss."""
        loss = kwargs['loss']
        if self.group_penalty > 0:
            # Weighted squared-L2 penalty for each filter, split by the sign of its coefficient
            Sp, Sr = 0., 0.
            for module, param_type in self.parameters_to_prune:
                p = getattr(module, param_type + "_orig", None)
                if p is None:
                    p = getattr(module, param_type)

                penaltyVec = self.penaltyVecDict[(module, param_type)]
                l_2_filters = (p.flatten(start_dim=1)**2).sum(dim=1)
                penalized = penaltyVec >= 0
                growing = penaltyVec < 0
                Sp += (penaltyVec[penalized] * l_2_filters[penalized]).sum()
                Sr += (penaltyVec[growing] * l_2_filters[growing]).sum()

            # tau is detached, so the added value is ~0 while the gradients of Sp and Sr remain
            with torch.no_grad():
                tau = -self.group_penalty * Sp/Sr
            loss = loss + self.group_penalty*Sp + tau*Sr    # This should result in 0 addition to the tensor, but that's exactly the way they did it in the paper

        return loss

class SFP(StructDense):
    """Soft Filter Pruning as by He et al. (2018).

    At the end of each eligible epoch, every Conv2d layer keeps only the ``sfp_k`` fraction of
    filters with the largest L2 norm; the others are zeroed (together with their momentum) but
    remain trainable.
    """
    required_params = ['sfp_k', 'sfp_start_epoch']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

        self.sfp_k = self.run_config['sfp_k']  # fraction of filters to keep per layer
        assert 0 < self.sfp_k < 1

        self.sfp_start_epoch = self.run_config['sfp_start_epoch']
        assert 0 <= self.sfp_start_epoch < self.run_config['n_epochs']

    def at_epoch_end(self, **kwargs):
        """Zero the lowest-norm filters of every layer, then make that permanent."""
        super().at_epoch_end(**kwargs)
        epoch = kwargs['epoch']
        if epoch == 0 or epoch == self.run_config['n_epochs'] or epoch < self.sfp_start_epoch:  # In the last epoch we do not prune, this is done by retraining
            return
        with torch.no_grad():
            for module, param_type in self.parameters_to_prune:
                p = getattr(module, param_type)

                # Prune the filters by L2 norm
                n_filters = p.shape[0]
                filters_to_keep = int(self.sfp_k * n_filters)

                filter_norms = torch.norm(p.flatten(start_dim=1), p=2, dim=1)
                top_indices = torch.topk(filter_norms, k=filters_to_keep).indices
                # Index the filter dimension directly. Writing through mask.flatten(start_dim=1) is lost
                # whenever flatten has to copy (e.g. channels_last weights, whose layout zeros_like keeps),
                # which would silently prune every filter.
                mask = torch.zeros_like(p)
                mask[top_indices] = 1
                prune.custom_from_mask(module=module, name=param_type, mask=mask)

            # Prune momentum, then make the pruning permanent
            self.prune_momentum()
            self.make_pruning_permanent()


## Spectrum based
class NUC(StructDense):
    """Nuclear norm penalization followed by SVD, roughly similar to Denton et al.

    Each prunable weight is flattened to a matrix (dim 0 kept, remaining dims merged), and its
    nuclear norm, scaled by ``group_penalty``, is added to the loss.
    """
    required_params = ['group_penalty']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        assert self.group_penalty > 0

    # Important: no torch.no_grad, the penalty must be differentiable
    def before_backward(self, **kwargs):
        """Return the loss plus the nuclear-norm penalty of every prunable weight."""
        loss = kwargs['loss']
        if not self.group_penalty > 0:
            return loss
        for module, param_type in self.parameters_to_prune:
            # Use weight_orig when the module carries a pruning reparametrization
            weight = getattr(module, param_type + "_orig", None)
            if weight is None:
                weight = getattr(module, param_type)
            as_matrix = weight.flatten(start_dim=1)
            loss = loss + self.group_penalty * torch.linalg.norm(as_matrix, ord='nuc')
        return loss

class SVDEnergy(StructDense):
    """Nuclear Thresholding approach by Alvarez et al. (2017). Implements a proximal step at the end of each epoch."""
    required_params = ['group_penalty']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        assert self.group_penalty > 0

        # Proximal operator built by ProximalOperator.svd_soft_thresholding with threshold group_penalty
        self.prox_operator = ProximalOperator.svd_soft_thresholding(threshold=self.group_penalty)

    def at_epoch_end(self, **kwargs):
        """Apply the proximal operator to every prunable weight (skipped at epoch 0 and the last epoch)."""
        super().at_epoch_end(**kwargs)
        current_epoch = kwargs['epoch']
        is_boundary_epoch = current_epoch == 0 or current_epoch == self.run_config['n_epochs']
        if is_boundary_epoch:
            return
        step_size = float(self.optimizer.param_groups[0]['lr'])
        with torch.no_grad():
            for module, param_type in self.parameters_to_prune:
                weight = getattr(module, param_type)
                weight.copy_(self.prox_operator(x=weight, lr=step_size))

class SVDEnergyIteration(StructDense):
    """Nuclear Thresholding approach by Alvarez et al. (2017). Implements a proximal step after each iteration."""
    required_params = ['group_penalty']

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.group_penalty = self.run_config['group_penalty'] or 0
        assert self.group_penalty > 0

        # Proximal operator built by ProximalOperator.svd_soft_thresholding with threshold group_penalty
        self.prox_operator = ProximalOperator.svd_soft_thresholding(threshold=self.group_penalty)

    def after_training_iteration(self, **kwargs):
        """Run the base-class hook, then apply one proximal step to every prunable weight."""
        super().after_training_iteration(**kwargs)
        step_size = float(self.optimizer.param_groups[0]['lr'])
        with torch.no_grad():
            for module, param_type in self.parameters_to_prune:
                weight = getattr(module, param_type)
                weight.copy_(self.prox_operator(x=weight, lr=step_size))
